"""
Plot yellowbrick discrimination thresholds to find an optimized decision threshold for each label's logistic-regression classifier.
"""
from sklearn.linear_model import LogisticRegression
from yellowbrick.classifier import DiscriminationThreshold
from yellowbrick.classifier.threshold import discrimination_threshold
import pandas as pd
from scipy.sparse import load_npz
import warnings


def warn(*args, **kwargs):
    """Swallow a warning instead of emitting it.

    Stand-in for ``warnings.warn`` that accepts any positional/keyword
    arguments and discards them, used to hide deprecation messages.
    Only code that looks up ``warnings.warn`` at call time is silenced;
    modules that already bound ``from warnings import warn`` are not.
    """
    return None

# Silence deprecation messages by replacing warnings.warn with a no-op.
warnings.warn = warn

# Feature matrix produced by the cleaning step.
X = load_npz('../data_clean/vectorized_text.npz')

# Target columns; one threshold plot is drawn per entry.
label_cols = [
    'S6_is_formal', 'S6_is_legal', 'S6_is_technical',
    'S6_is_aggressive', 'S8_dummy_Activities', 'S8_dummy_Budget',
    'S8_dummy_Evaluation', 'S8_dummy_ExternalContracts',
    'S8_dummy_InstStruc', 'S8_dummy_Other', 'S8_dummy_Regulatory',
    'S9_dummy_Academic/Scholarly', 'S9_dummy_Commercial',
    'S9_dummy_Impossible to say', 'S9_dummy_Monitoring',
    'S9_dummy_Personal', 'S10_is_clear',
    'S10_is_competency_of_institution', 'S10_is_public',
    'S10_is_existant', 'S11_dummy_Date',
    'S11_dummy_Document', 'S11_dummy_Institution',
    'S11_dummy_Organization', 'S11_dummy_Person', 'S11_dummy_Place',
]

# Load every target column, not just the first one: the loop below needs a
# target for each label. Selecting only 'S6_is_formal' made positional
# indexing fail (IndexError) for every label after the first.
y_all = pd.read_csv('../data_clean/labels.csv')[label_cols]

# Only the first N_LABELLED rows of X and of the labels are used.
# NOTE(review): presumably these are the rows that carry labels — confirm.
N_LABELLED = 3936
X_labelled = X[:N_LABELLED]  # loop-invariant, so slice once

for name in label_cols:
    # A fresh, unfitted estimator for each label.
    model = LogisticRegression(C=1,
                               solver='lbfgs', max_iter=500,
                               multi_class='auto')
    print(name)
    # `cv` is the holdout fraction used by discrimination_threshold. The
    # former `cvfloat=` keyword is not one of its parameters and never set it.
    # Targets are looked up by column name so they always match `name`.
    discrimination_threshold(
        model,
        X_labelled,
        y_all[name].to_numpy()[:N_LABELLED],
        cv=0.2,
        exclude="queue_rate",
    )